import absl.flags

# import d4rl  # type: ignore
import gym
import numpy as np
# import seaborn_image as isns

# from data import Dataset, RandSampler, SlidingWindowSampler, RLUPDataset, DM2Gym
from utilities.replay_buffer import get_d4rl_dataset
from utilities.traj_dataset import get_traj_dataset, compute_returns
from utilities.sampler import TrajSampler
from utilities.utils import (
  # SOTALogger,
  WandBLogger,
  define_flags_with_default,
  get_user_flags,
)
from pathlib import Path

import matplotlib.pyplot as plt

def vis(env, bins, m, n, i, out_dir='vis_results/new'):
  """Plot and save a histogram of per-trajectory returns for one environment.

  The image is written to ``<out_dir>/<env>.jpg``. The directory is created
  if it does not already exist.

  Args:
    env: gym environment id passed to ``gym.make``, e.g.
      ``'hopper-full-replay-v2'``.
    bins: number of histogram bins.
    m, n, i: subplot grid rows / cols / index. Currently unused by the
      plotting code; kept so existing callers keep working.
    out_dir: directory the histogram image is saved into.
  """
  # NOTE(review): the sampler is only used to obtain `.env`; this assumes
  # that is the env the trajectory dataset should be loaded from.
  eval_sampler = TrajSampler(gym.make(env).unwrapped, 1000)

  # Unsorted trajectory dataset; only the per-trajectory returns are plotted.
  traj_dataset, _ = get_traj_dataset(eval_sampler.env, sorting=False)
  traj_returns = [compute_returns(traj) for traj in traj_dataset]

  out_path = Path(out_dir)
  # savefig does not create missing directories, so ensure it exists first.
  out_path.mkdir(parents=True, exist_ok=True)

  fig, ax = plt.subplots()
  try:
    ax.hist(traj_returns, bins=bins)
    ax.set_title(env)
    fig.savefig(out_path / f'{env}.jpg')
  finally:
    # Always release the figure, even if plotting or saving fails, so that
    # repeated calls from a loop don't accumulate open figures.
    plt.close(fig)


def main():
  """Save a trajectory-return histogram for each configured environment."""
  mujoco_envs = [
    f'{env}-{level}-v2'
    for level in [
      # 'medium',
      # 'medium-replay',
      # 'medium-expert',
      'full-replay',
    ]
    for env in ['halfcheetah', 'hopper', 'walker2d']
  ]
  # Other suites; uncomment here and in `suites` below to include them.
  # antmaze_envs = ['antmaze-umaze-v0', 'antmaze-umaze-diverse-v0', 'antmaze-medium-play-v0',
  #     'antmaze-medium-diverse-v0', 'antmaze-large-play-v0', 'antmaze-large-diverse-v0']
  # adroit_envs =  ['pen-human-v0', 'hammer-human-v0', 'door-human-v0', 'relocate-human-v0', 'pen-cloned-v0',
  #     'hammer-cloned-v0', 'door-cloned-v0', 'relocate-cloned-v0',
  #     'kitchen-complete-v0', 'kitchen-partial-v0', 'kitchen-mixed-v0'
  #     ]

  # Subplot grid shape forwarded to vis() (currently unused there).
  n = 1
  m = 3

  # (env ids, histogram bins) per suite, in plotting order. The 1-based plot
  # index keeps counting across suites rather than restarting per suite.
  suites = [
    # (adroit_envs, 40),
    (mujoco_envs, 100),
    # (antmaze_envs, 10),
  ]
  plots = [(env, bins) for envs, bins in suites for env in envs]
  for i, (env, bins) in enumerate(plots, start=1):
    vis(env, bins, m, n, i)

if __name__ == '__main__':
  # Run the plotting pipeline only when executed as a script, not on import.
  main()